﻿using System.Transactions;

namespace HZY.Framework.Repository.EntityFramework.Interceptor;

/// <summary>
/// Read/write splitting interceptor (读写分离拦截器).
/// </summary>
/// <remarks>
/// <para>
/// Reader commands that do not look like writes and run outside any transaction are sent to a slave:
/// a connection string picked at random from <c>SlaveConnectionStrings</c> of the repository options
/// (only for contexts implementing <c>IBaseDbContext</c>). The following always run on the master:
/// non-query and scalar commands, any command containing a write/DDL keyword, and everything inside
/// a local or ambient transaction.
/// </para>
/// <para>
/// Routing works by rewriting <see cref="DbConnection.ConnectionString"/> on the context's own connection.
/// The master connection string is remembered per connection and put back in three cases:
/// when the last slave reader on that connection is disposed, when its command fails, or before the next
/// master-bound command. This keeps later writes and transactions from staying on a slave.
/// </para>
/// <para>
/// NOTE(review): a connection is assumed to be used by one thread at a time (as a DbContext is); the
/// per-connection bookkeeping is not synchronized. A connection is never moved while a transaction is
/// active on it, so a transaction started while a slave reader is still open stays on that slave.
/// </para>
/// </remarks>
public class ReadWriteCommandInterceptor : DbCommandInterceptor
{
    /// <summary>
    /// Keywords that keep a command on the master when they appear as a whole word anywhere in its text
    /// (case-insensitive).
    /// </summary>
    /// <remarks>
    /// The whole text is scanned, not just its start. This also catches:
    /// statements preceded by other SQL (e.g. <c>SET NOCOUNT ON; INSERT ...</c>), CTEs that end in DML,
    /// <c>SELECT ... INTO</c>, and locking reads such as <c>SELECT ... FOR UPDATE</c>.
    /// A false positive only means a read is served by the master.
    /// </remarks>
    private static readonly System.Text.RegularExpressions.Regex MasterOnlyKeywords = new(
        @"\b(?:insert|update|delete|merge|into|create|alter|drop|truncate|exec|execute|call)\b",
        System.Text.RegularExpressions.RegexOptions.IgnoreCase
        | System.Text.RegularExpressions.RegexOptions.CultureInvariant
        | System.Text.RegularExpressions.RegexOptions.Compiled);

    /// <summary>Connections currently switched to a slave, with what is needed to switch them back.</summary>
    private static readonly System.Runtime.CompilerServices.ConditionalWeakTable<DbConnection, SlaveRoute> RoutesByConnection = new();

    /// <summary>Reader commands currently running on a slave, with the route they hold open.</summary>
    private static readonly System.Runtime.CompilerServices.ConditionalWeakTable<DbCommand, SlaveRoute> RoutesByCommand = new();

    #region Synchronous hooks (同步方法)

    /// <summary>
    /// Sends a read outside a transaction to a slave; any other reader command goes to the master.
    /// </summary>
    public override InterceptionResult<DbDataReader> ReaderExecuting(
        DbCommand command,
        CommandEventData eventData,
        InterceptionResult<DbDataReader> result)
    {
        Route(command, eventData, isReader: true);
        return base.ReaderExecuting(command, eventData, result);
    }

    /// <summary>Non-query commands modify data or schema, so they always run on the master.</summary>
    public override InterceptionResult<int> NonQueryExecuting(
        DbCommand command,
        CommandEventData eventData,
        InterceptionResult<int> result)
    {
        Route(command, eventData, isReader: false);
        return base.NonQueryExecuting(command, eventData, result);
    }

    /// <summary>
    /// Scalar commands are treated like writes and run on the master, because a scalar statement can have
    /// side effects (e.g. taking the next value of a sequence).
    /// </summary>
    public override InterceptionResult<object> ScalarExecuting(
        DbCommand command,
        CommandEventData eventData,
        InterceptionResult<object> result)
    {
        Route(command, eventData, isReader: false);
        return base.ScalarExecuting(command, eventData, result);
    }

    /// <summary>Switches the connection back to the master once its last slave reader is disposed.</summary>
    public override InterceptionResult DataReaderDisposing(
        DbCommand command,
        DataReaderDisposingEventData eventData,
        InterceptionResult result)
    {
        ReleaseSlaveReader(command, eventData, reopen: true);
        return base.DataReaderDisposing(command, eventData, result);
    }

    /// <summary>
    /// Releases the slave route of a failed command, since a failed command may never produce a reader
    /// to dispose. The connection is left closed rather than reopened on the master, so a master failure
    /// cannot hide the original error.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this relies on EF reopening a closed connection before the next command;
    /// confirm for connections the application opened explicitly.
    /// </remarks>
    public override void CommandFailed(DbCommand command, CommandErrorEventData eventData)
    {
        ReleaseSlaveReader(command, eventData, reopen: false);
        base.CommandFailed(command, eventData);
    }

    #endregion

    #region Asynchronous hooks (异步方法)

    /// <summary>Async counterpart of <see cref="ReaderExecuting"/>.</summary>
    public override async ValueTask<InterceptionResult<DbDataReader>> ReaderExecutingAsync(
        DbCommand command,
        CommandEventData eventData,
        InterceptionResult<DbDataReader> result,
        CancellationToken cancellationToken = default)
    {
        await RouteAsync(command, eventData, isReader: true, cancellationToken).ConfigureAwait(false);
        return await base.ReaderExecutingAsync(command, eventData, result, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>Async counterpart of <see cref="NonQueryExecuting"/>: always the master.</summary>
    public override async ValueTask<InterceptionResult<int>> NonQueryExecutingAsync(
        DbCommand command,
        CommandEventData eventData,
        InterceptionResult<int> result,
        CancellationToken cancellationToken = default)
    {
        await RouteAsync(command, eventData, isReader: false, cancellationToken).ConfigureAwait(false);
        return await base.NonQueryExecutingAsync(command, eventData, result, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>Async counterpart of <see cref="ScalarExecuting"/>: always the master.</summary>
    public override async ValueTask<InterceptionResult<object>> ScalarExecutingAsync(
        DbCommand command,
        CommandEventData eventData,
        InterceptionResult<object> result,
        CancellationToken cancellationToken = default)
    {
        await RouteAsync(command, eventData, isReader: false, cancellationToken).ConfigureAwait(false);
        return await base.ScalarExecutingAsync(command, eventData, result, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>Async counterpart of <see cref="CommandFailed"/>.</summary>
    public override Task CommandFailedAsync(
        DbCommand command,
        CommandErrorEventData eventData,
        CancellationToken cancellationToken = default)
    {
        ReleaseSlaveReader(command, eventData, reopen: false);
        return base.CommandFailedAsync(command, eventData, cancellationToken);
    }

    #endregion

    #region Routing

    /// <summary>
    /// Moves the command's connection to the database chosen by <see cref="TryResolveTarget"/>, if a
    /// move is needed.
    /// </summary>
    private static void Route(DbCommand command, CommandEventData eventData, bool isReader)
    {
        var connection = command.Connection;
        if (connection is null || !TryResolveTarget(connection, command, eventData, isReader, out var target))
        {
            return;
        }

        // The command is about to run, so the connection must be open again afterwards.
        SwitchConnectionString(connection, target, reopen: true);
    }

    /// <summary>Async counterpart of <see cref="Route"/>; it avoids blocking on close/open.</summary>
    private static async Task RouteAsync(
        DbCommand command,
        CommandEventData eventData,
        bool isReader,
        CancellationToken cancellationToken)
    {
        var connection = command.Connection;
        if (connection is null || !TryResolveTarget(connection, command, eventData, isReader, out var target))
        {
            return;
        }

        // Close first so that the connection string can be replaced.
        if (connection.State != ConnectionState.Closed)
        {
            await connection.CloseAsync().ConfigureAwait(false);
        }

        connection.ConnectionString = target;
        await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Decides whether <paramref name="connection"/> must change database before <paramref name="command"/>
    /// runs, and updates the routing bookkeeping accordingly.
    /// </summary>
    /// <param name="connection">The command's connection.</param>
    /// <param name="command">The command about to execute.</param>
    /// <param name="eventData">EF event data; provides the context.</param>
    /// <param name="isReader">Only reader commands may be served by a slave.</param>
    /// <param name="target">The connection string to switch to when the method returns <c>true</c>.</param>
    /// <returns><c>true</c> if the connection must be switched; <c>false</c> to leave it as it is.</returns>
    private static bool TryResolveTarget(
        DbConnection connection,
        DbCommand command,
        CommandEventData eventData,
        bool isReader,
        out string target)
    {
        target = string.Empty;
        var readFromSlave = isReader && CanReadFromSlave(command, eventData);

        if (RoutesByConnection.TryGetValue(connection, out var current))
        {
            if (readFromSlave)
            {
                // Already on a slave, e.g. another slave reader is still open on this connection.
                // Share it instead of reconnecting, which would kill that reader.
                Track(command, current);
                return false;
            }

            if (IsInTransaction(eventData))
            {
                // Closing the connection would abort the transaction, so it has to stay where it is.
                return false;
            }

            RoutesByConnection.Remove(connection);
            target = current.MasterConnectionString;
            return true;
        }

        if (!readFromSlave || !TryPickSlave(eventData, out var slave))
        {
            return false;
        }

        // Remember how to get back to the master before leaving it.
        // NOTE(review): assumes EF's configured connection string is not affected by rewriting
        // DbConnection.ConnectionString below (the case when the context is configured with a connection
        // string rather than a DbConnection instance) — confirm.
        var master = eventData.Context is null
            ? null
            : Microsoft.EntityFrameworkCore.RelationalDatabaseFacadeExtensions.GetConnectionString(eventData.Context.Database);
        if (string.IsNullOrEmpty(master))
        {
            // Without the master connection string the connection could never be switched back.
            return false;
        }

        var route = new SlaveRoute(master);
        RoutesByConnection.AddOrUpdate(connection, route);
        Track(command, route);
        target = slave;
        return true;
    }

    /// <summary>
    /// Releases the slave route held by <paramref name="command"/>. When it was the last one, the connection
    /// is switched back to the master.
    /// </summary>
    /// <param name="command">A command that may have been routed to a slave.</param>
    /// <param name="eventData">EF event data; used for the transaction check.</param>
    /// <param name="reopen">Reopen the connection on the master if it was open.</param>
    private static void ReleaseSlaveReader(DbCommand command, DbContextEventData eventData, bool reopen)
    {
        if (!RoutesByCommand.TryGetValue(command, out var route))
        {
            // This command did not run on a slave.
            return;
        }

        RoutesByCommand.Remove(command);
        route.OpenReaders = Math.Max(0, route.OpenReaders - 1);

        var connection = command.Connection;
        if (route.OpenReaders > 0
            || connection is null
            || !RoutesByConnection.TryGetValue(connection, out var current)
            || !ReferenceEquals(current, route))
        {
            // Other slave readers still use the connection, or it has already been switched back.
            return;
        }

        if (IsInTransaction(eventData))
        {
            // Closing now would abort the transaction. The route is kept, and a later master-bound
            // command outside the transaction switches the connection back.
            return;
        }

        RoutesByConnection.Remove(connection);
        var wasOpen = connection.State == ConnectionState.Open;
        SwitchConnectionString(connection, route.MasterConnectionString, reopen && wasOpen);
    }

    /// <summary>Records that <paramref name="command"/> runs on the slave described by <paramref name="route"/>.</summary>
    private static void Track(DbCommand command, SlaveRoute route)
    {
        route.OpenReaders++;
        RoutesByCommand.AddOrUpdate(command, route);
    }

    /// <summary>
    /// Returns <c>true</c> when the command may be served by a slave. That requires three things:
    /// the context is an <c>IBaseDbContext</c>, the SQL contains no master-only keyword, and no
    /// transaction is active.
    /// </summary>
    private static bool CanReadFromSlave(DbCommand command, CommandEventData eventData) =>
        eventData.Context is IBaseDbContext
        && !IsMasterOnlyCommand(command.CommandText)
        && !IsInTransaction(eventData);

    /// <summary>
    /// Returns <c>true</c> for empty command text or text containing a write/DDL keyword
    /// (see <see cref="MasterOnlyKeywords"/>).
    /// </summary>
    private static bool IsMasterOnlyCommand(string commandText) =>
        string.IsNullOrWhiteSpace(commandText) || MasterOnlyKeywords.IsMatch(commandText);

    /// <summary>
    /// Picks a slave connection string at random from the context's repository options.
    /// </summary>
    /// <param name="eventData">EF event data; provides the context.</param>
    /// <param name="slave">The chosen connection string when the method returns <c>true</c>.</param>
    /// <returns>
    /// <c>false</c> when no slaves are configured or the picked entry is blank; the command then stays on
    /// the master.
    /// </returns>
    private static bool TryPickSlave(DbContextEventData eventData, out string slave)
    {
        slave = string.Empty;
        if (eventData.Context is not IBaseDbContext context)
        {
            return false;
        }

        var candidates = context.GetRepositoryOptions()?.SlaveConnectionStrings;
        if (candidates is null || candidates.Length == 0)
        {
            return false;
        }

        var picked = candidates.Length == 1
            ? candidates[0]
            : candidates[Random.Shared.Next(candidates.Length)];
        if (string.IsNullOrWhiteSpace(picked))
        {
            return false;
        }

        slave = picked;
        return true;
    }

    /// <summary>
    /// Returns <c>true</c> in either of two cases: an ambient (System.Transactions) transaction exists
    /// that is not yet committed, or the context has a current database transaction.
    /// </summary>
    private static bool IsInTransaction(DbContextEventData eventData)
    {
        var ambient = Transaction.Current;
        if (ambient is not null && ambient.TransactionInformation.Status != TransactionStatus.Committed)
        {
            return true;
        }

        return eventData.Context?.Database.CurrentTransaction is not null;
    }

    /// <summary>
    /// Replaces the connection string of <paramref name="connection"/>, closing it first.
    /// </summary>
    /// <param name="connection">The connection to retarget.</param>
    /// <param name="connectionString">The new connection string.</param>
    /// <param name="reopen">Open the connection again after the change.</param>
    private static void SwitchConnectionString(DbConnection connection, string connectionString, bool reopen)
    {
        // Close first so that the connection string can be replaced.
        if (connection.State != ConnectionState.Closed)
        {
            connection.Close();
        }

        connection.ConnectionString = connectionString;

        if (reopen)
        {
            connection.Open();
        }
    }

    #endregion

    /// <summary>Bookkeeping for a connection that has been switched to a slave.</summary>
    private sealed class SlaveRoute
    {
        public SlaveRoute(string masterConnectionString) => MasterConnectionString = masterConnectionString;

        /// <summary>Connection string to put back when the connection returns to the master.</summary>
        public string MasterConnectionString { get; }

        /// <summary>Number of slave-routed reader commands not yet released.</summary>
        public int OpenReaders { get; set; }
    }
}
